/*
 * @Author: gaoxinglong
 * @Date: 2022-11-24 15:10:18
 * @LastEditTime: 2022-11-30 20:05:30
 * @LastEditors: gaoxinglong
 */
#include "online_streaming.h"
#include <algorithm>
#include <memory>
#include <utility>

/**
 * Builds a streaming helper from the ASR configuration.
 *
 * Only the four values below are read from the config; every other member
 * keeps whatever initial value the class declaration gives it.
 *
 * @param asr_config Source of the feature dimension and context lengths.
 */
OnlineStreaming::OnlineStreaming(const wenet::StreamingAsrConfig &asr_config)
    : input_dim_(asr_config.input_dim),
      // NOTE(review): N_l feeds the left context and N_c feeds the right
      // context -- confirm this mapping against the config's definition.
      left_context_length_(asr_config.N_l),
      right_context_length_(asr_config.N_c),
      conv_context_length_(asr_config.conv_context) {
}

/// Destructor: performs no explicit cleanup; members release themselves.
OnlineStreaming::~OnlineStreaming() = default;

/**
 * Appends a chunk of feature values to the internal frame buffer.
 *
 * After appending, the end-of-input flag is overwritten with @p is_last and
 * the number of complete frames held in the buffer is recomputed as
 * buffer size divided by input_dim_.
 *
 * @param new_chunk Feature values to append, in order.
 * @param is_last   Stored as-is into input_finished_.
 */
void OnlineStreaming::ReceiveNewChunk(const std::vector<float> &new_chunk,
                                      bool is_last) {
  // Append the whole chunk in a single call.
  frame_buffer_.insert(frame_buffer_.end(), new_chunk.begin(),
                       new_chunk.end());

  input_finished_ = is_last;

  // Integer division: a trailing partial frame is not counted.
  accum_frames_ = frame_buffer_.size() / input_dim_;

  std::cout << "streaming_features_, get new features, accum_frames_: "
            << accum_frames_ << std::endl;
}

/**
 * Extracts the next block of features from the accumulated frame buffer.
 *
 * The block spans frames [temp_start_, temp_end_), where
 *   temp_start_ = max(0, offset_ - conv_context_length_)
 *   temp_end_   = min(offset_ + current_context_length_ +
 *                     left_context_length_ + conv_context_length_,
 *                     accum_frames_)
 * and each frame occupies input_dim_ consecutive floats in frame_buffer_.
 *
 * While input is still arriving (input_finished_ == false) a block is only
 * produced once the full, unclamped span is available; otherwise the
 * function returns with @p block_output empty and offset_ unchanged.
 * Once input is finished the span is clamped to the frames that exist; if
 * nothing is left to emit, @p block_output is returned empty and offset_ is
 * left unchanged.
 *
 * On success offset_ advances by current_context_length_, and
 * temp_block_frame_size_ / cnn_lookahead_ (and, when input is finished,
 * cnn_lookback_) are updated to describe the emitted block.
 *
 * @param block_output Cleared on entry; receives
 *                     temp_block_frame_size_ * input_dim_ floats on success.
 */
void OnlineStreaming::GetBlock(std::vector<float> &block_output) {
  block_output.clear();

  temp_start_ = std::max(0, offset_ - conv_context_length_);

  // End of the block if enough frames were available (in frames).
  const int32_t ideal_block_end = offset_ + current_context_length_ +
                                  left_context_length_ + conv_context_length_;

  if (!input_finished_) {
    std::cout << "ideal_block_end:" << ideal_block_end << std::endl;
    if (ideal_block_end > accum_frames_) {
      // not enough data to decode
      return;
    }
  }

  temp_end_ = std::min<int32_t>(ideal_block_end, accum_frames_);
  fprintf(stderr, "temp_start_:%d, temp_end_:%d\n", temp_start_, temp_end_);
  temp_block_frame_size_ = temp_end_ - temp_start_;

  if (input_finished_) {
    // NOTE(review): cnn_lookback_ is refreshed only once input is finished;
    // while streaming it keeps its previous value. This mirrors the earlier
    // behaviour -- confirm whether it should be updated in both modes.
    cnn_lookback_ = temp_start_ > 0;
  }
  cnn_lookahead_ = temp_end_ < accum_frames_;

  if (temp_block_frame_size_ <= 0) {
    // Nothing left to emit (e.g. all finished input already consumed).
    // A non-positive size must never reach resize()/copy.
    return;
  }

  // temp_start_/temp_end_ are frame indices, while frame_buffer_ is indexed
  // in floats, so scale by input_dim_ and copy exactly the block's span.
  const int32_t block_size = temp_block_frame_size_ * input_dim_;
  const auto block_begin = frame_buffer_.begin() + temp_start_ * input_dim_;
  block_output.assign(block_begin, block_begin + block_size);

  offset_ += current_context_length_;
}